/********************************************************************************
 * Copyright 2009 The Robotics Group, The Maersk Mc-Kinney Moller Institute,
 * Faculty of Engineering, University of Southern Denmark
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ********************************************************************************/

#include "EAA.hpp"

#include "Math.hpp"

#include <rw/common/InputArchive.hpp>
#include <rw/common/OutputArchive.hpp>

using namespace rw::math;

namespace {
/**
 * @brief Convert a rotation matrix into its equivalent angle-axis (EAA) vector.
 *
 * The result is theta * k, where k is the unit rotation axis and theta is the
 * rotation angle in [0, Pi]. Three regimes are handled separately:
 *  - theta close to 0: the axis is ill-defined, so the first-order
 *    approximation 0.5 * (R - R^T)^vee = sin(theta) * k ~ theta * k is returned.
 *  - theta close to Pi: the antisymmetric part of R (almost) vanishes, so the
 *    axis magnitudes are recovered from the diagonal of R and their relative
 *    signs from the symmetric off-diagonal terms.
 *  - otherwise: the normalized antisymmetric part of R scaled by theta.
 *
 * @param R [in] rotation matrix (assumed orthonormal).
 * @return the EAA vector theta * k.
 */
template< class T > Vector3D< T > angleAxis (const Rotation3D< T >& R)
{
    typedef Vector3D< T > V;

    const T eps = static_cast< T > (1e-6);

    // trace(R) = 1 + 2 cos(theta)
    T cos_theta = (T) 0.5 * (R (0, 0) + R (1, 1) + R (2, 2) - 1);

    // Numerical rounding errors force us to make this check:
    if (cos_theta > 1)
        cos_theta = 1;
    else if (cos_theta < -1)
        cos_theta = -1;

    // ... because otherwise this often yields NaN:
    const T angle = acos (cos_theta);
    // Antisymmetric part: (R - R^T)^vee = 2 sin(theta) * k
    const V axis (R (2, 1) - R (1, 2), R (0, 2) - R (2, 0), R (1, 0) - R (0, 1));

    if (fabs (angle) < eps) {    // Is angle close to 0 degree
        return 0.5 * axis;
    }

    if (fabs (angle - Pi) < eps) {    // Is the angle close to 180 degree
        // Diagonal: R(i,i) = cos + (1 - cos) k_i^2, so with cos = -1 we get
        // k_i^2 = (R(i,i) + 1) / 2. Negative values are rounding noise.
        V values ((T) 0.5 * (R (0, 0) + (T) 1.0),
                  (T) 0.5 * (R (1, 1) + (T) 1.0),
                  (T) 0.5 * (R (2, 2) + (T) 1.0));
        values[0] = (values[0] < 0) ? 0 : sqrt (values[0]);
        values[1] = (values[1] < 0) ? 0 : sqrt (values[1]);
        values[2] = (values[2] < 0) ? 0 : sqrt (values[2]);

        // Use the axis component with the largest MAGNITUDE as sign reference.
        // It must be picked from 'values', not from 'axis': at theta == Pi the
        // vector 'axis' is (numerically) zero, and a reference component whose
        // magnitude is zero makes the relative-sign tests below meaningless.
        std::size_t k = 0;
        if (values[1] > values[k])
            k = 1;
        if (values[2] > values[k])
            k = 2;

        // Slightly below Pi the residual antisymmetric part still carries the
        // sign of k_k; exactly at Pi both signs describe the same rotation.
        const int sign_k = (axis[k] >= 0) ? 1 : -1;
        values *= static_cast< T > (sign_k);

        // Symmetric off-diagonals: R(i,j) + R(j,i) = 2 (1 - cos) k_i k_j, so
        // their sign is the sign of k_i * k_k.
        for (std::size_t i = 0; i < 3; ++i) {
            if (i != k && R (i, k) + R (k, i) < 0)
                values[i] = -values[i];
        }

        // |axis| = 2 sin(theta) and sin(theta) ~ Pi - theta near Pi, so this
        // recovers theta more accurately than acos() does in this region.
        return values * (static_cast< T > (Pi) - axis.norm2 () / 2);
    }

    return normalize (axis) * angle;
}
}    // namespace

/**
 * Construct the equivalent angle-axis representation of the rotation @a R.
 * The matrix-to-EAA conversion is performed by the file-local angleAxis().
 */
template< class T >
EAA< T >::EAA (const Rotation3D< T >& R) :
    _eaa (angleAxis (R))
{
}

/**
 * Construct the EAA that rotates direction @a v1 onto direction @a v2.
 *
 * The axis is normalize(v1 x v2) and the angle is acos(v1 . v2). Because the
 * dot product is used directly as cos(angle), the inputs are presumably
 * expected to be unit vectors — NOTE(review): confirm against the header.
 *
 * @param v1 [in] start direction.
 * @param v2 [in] target direction.
 */
template< class T >
EAA< T >::EAA (const rw::math::Vector3D< T >& v1, const rw::math::Vector3D< T >& v2) :
    _eaa (0, 0, 0)
{
    const T epsilon = (T) 1e-15;
    T dval          = rw::math::dot (v1, v2);

    // Rounding (notably for T = float) can push the dot product of two unit
    // vectors slightly outside [-1, 1], where acos() returns NaN. Clamp it so
    // such inputs are handled by the (anti-)parallel branches below instead.
    if (dval > 1)
        dval = 1;
    else if (dval < -1)
        dval = -1;

    if (fabs (dval - 1) < epsilon) {
        // The vectors are (almost) parallel: the angle is (almost) 0 and the
        // perpendicular axis cannot be determined reliably, so a zero EAA is
        // used as the approximation.
        _eaa = Vector3D< T > (0, 0, 0);
    }
    else if (fabs (dval + 1) < epsilon) {
        // The vectors are (almost) opposite: the angle is (almost) 180 degrees
        // and any axis perpendicular to v1 will do. idx ends up as the index of
        // v1's smallest-magnitude component; crossing v1 with that unit vector
        // gives the largest (best-conditioned) cross product.
        int idx = 0;
        if (fabs (v1 (0)) > fabs (v1 (1)))
            idx = 1;
        if (fabs (v1 (idx)) > fabs (v1 (2)))
            idx = 2;
        Vector3D< T > v3 (0, 0, 0);
        v3 (idx) = 1;
        _eaa     = normalize (rw::math::cross (v1, v3)) * (T) Pi;
    }
    else {
        const T theta = acos (dval);
        _eaa          = normalize (rw::math::cross (v1, v2)) * theta;
    }
}

/**
 * Convert this EAA to a rotation matrix using Rodrigues' formula
 *   R = cos(theta) I + sin(theta) [k]x + (1 - cos(theta)) k k^T
 * with theta = angle() and k = axis().
 */
template< class T > const Rotation3D< T > EAA< T >::toRotation3D () const
{
    const T theta = angle ();
    const T c     = cos (theta);
    const T s     = sin (theta);
    const T t     = 1 - c;    // versine of theta

    const Vector3D< T > k = axis ();
    const T x             = k[0];
    const T y             = k[1];
    const T z             = k[2];

    // Off-diagonal entries of the symmetric (1 - cos) k k^T term; each appears
    // twice in the matrix, once plus and once minus a sin-weighted component.
    const T xyt = x * y * t;
    const T xzt = x * z * t;
    const T yzt = y * z * t;

    return Rotation3D< T > (x * x * t + c, xyt - z * s, xzt + y * s,
                            xyt + z * s, y * y * t + c, yzt - x * s,
                            xzt - y * s, yzt + x * s, z * z * t + c);
}

// Explicit instantiations of EAA for the supported scalar types.
template class rw::math::EAA< float >;
template class rw::math::EAA< double >;

/**
 * Write an EAA<double> to @a oar as a flat vector of its tmp.size()
 * components, stored under @a id with the type tag "EAA".
 */
template<>
void rw::common::serialization::write (const EAA< double >& tmp, rw::common::OutputArchive& oar,
                                       const std::string& id)
{
    const int count = static_cast< int > (tmp.size ());
    oar.write (rw::math::Math::toStdVector (tmp, count), id, "EAA");
}
/**
 * Read an EAA<double> stored under @a id with the type tag "EAA" from @a iar.
 * The archive holds the components as a flat std::vector<double>.
 * NOTE(review): the element count read is not checked against tmp.size()
 * here — confirm that Math::fromStdVector guards against oversized input.
 */
template<>
void rw::common::serialization::read (EAA< double >& tmp, rw::common::InputArchive& iar,
                                      const std::string& id)
{
    std::vector< double > components;
    iar.read (components, id, "EAA");
    rw::math::Math::fromStdVector (components, tmp);
}
/**
 * Write an EAA<float> to @a oar as a flat vector of its tmp.size()
 * components, stored under @a id with the type tag "EAA".
 */
template<>
void rw::common::serialization::write (const EAA< float >& tmp, rw::common::OutputArchive& oar,
                                       const std::string& id)
{
    const int count = static_cast< int > (tmp.size ());
    oar.write (rw::math::Math::toStdVector (tmp, count), id, "EAA");
}
/**
 * Read an EAA<float> stored under @a id with the type tag "EAA" from @a iar.
 * The archive holds the components as a flat std::vector<float>.
 * NOTE(review): the element count read is not checked against tmp.size()
 * here — confirm that Math::fromStdVector guards against oversized input.
 */
template<>
void rw::common::serialization::read (EAA< float >& tmp, rw::common::InputArchive& iar,
                                      const std::string& id)
{
    std::vector< float > components;
    iar.read (components, id, "EAA");
    rw::math::Math::fromStdVector (components, tmp);
}
